import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Only GraphMixer is part of the public API; the remaining classes in this
# module are the sub-modules it is assembled from.
__all__ = ['GraphMixer']


class GraphMixer(nn.Module):
    """Link-prediction model built from an MLP-Mixer node encoder and a link classifier.

    ``format_input`` turns a batch of edges into, for every node, the indices
    of up to ``num_neighbours`` of its incident edges in that batch.
    ``MLPMixer`` encodes those into one embedding per node, and
    ``LinkClassifier`` scores each ``(src, dst)`` pair from the two endpoint
    embeddings.

    Extra keyword arguments to ``__init__`` are accepted and ignored.
    """

    def __init__(self, num_nodes, msg_dim, time_dim=100, model_dim=100, token_expansion=0.5, channel_expansion=4,
                 num_layers=2, dropout=0.5, num_neighbours=10, pooling_method="mean", **kwargs):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_neighbours = num_neighbours

        self.mlp_mixer = MLPMixer(num_nodes, msg_dim, time_dim, model_dim, token_expansion, channel_expansion,
                                  num_layers, dropout, pooling_method)
        self.link_classifier = LinkClassifier(dim=model_dim)
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='mean')
        # Target device used by format_input(); None until set_device() is called.
        self.device = None
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters of both sub-modules."""
        self.mlp_mixer.reset_parameters()
        self.link_classifier.reset_parameters()

    def loss(self, y, logits, **kwargs):
        """Mean binary cross-entropy with logits; targets ``y`` are cast to float."""
        return self.loss_fn(logits, y.float())

    def forward(self, src, dst, t, msg, mask, **kwargs):
        """Score the candidate links ``(src[i], dst[i])``.

        Args:
            src, dst: node indices of the links to score.
            t, msg, mask: per-edge times, per-edge features and the neighbour
                index mask, as produced by ``format_input``.

        Returns:
            ``{"logits": tensor}`` holding one logit per entry of ``src``, with
            the same shape as ``src``.
        """
        z = self.mlp_mixer(t, msg, mask)
        logits = self.link_classifier(z[src], z[dst])
        # Reshape rather than squeeze(): a batch holding a single edge must keep
        # its batch axis so the logits line up with ``y`` in loss().
        return {"logits": logits.reshape(src.shape)}

    def set_device(self, device):
        """Remember ``device`` for format_input() and move all parameters to it."""
        self.device = device
        # Module.to() recurses into the registered sub-modules (mlp_mixer,
        # link_classifier), so one call moves everything.
        self.to(device)

    def format_input(self, data):
        """Build model inputs and targets from a batch of edges.

        ``data`` must expose ``src``, ``dst``, ``t``, ``msg`` and ``y``.
        Row ``i`` of the returned ``mask`` (for ``i`` in ``range(num_nodes)``)
        holds the indices of the batch edges incident to node ``i`` (as source
        or destination), highest index first, truncated to ``num_neighbours``.
        Unused slots hold ``num_edges``; that out-of-range index is mapped to
        an all-zero feature row inside ``MLPMixer``.

        Returns:
            ``(data_inputs, data_targets)`` dicts of tensors moved to
            ``self.device``.
        """
        src, dst, t, msg = data.src, data.dst, data.t, data.msg
        num_edges = src.shape[0]
        k = self.num_neighbours
        # Build the index buffer as int64 from the start: a float32 buffer
        # cannot hold every edge index exactly once num_edges exceeds 2**24.
        mask = torch.full((self.num_nodes, k), num_edges, dtype=torch.long, device=src.device)
        for node in range(self.num_nodes):
            incident = torch.nonzero((src == node) | (dst == node), as_tuple=True)[0]
            # Highest edge index first (presumably the most recent edge if the
            # batch is in chronological order — confirm with the data loader).
            recent = torch.flip(incident, [0])[:k]
            mask[node, :recent.numel()] = recent
        data_inputs = {"src": src.long().to(self.device),
                       "dst": dst.long().to(self.device),
                       "t": t.float().to(self.device),
                       "msg": msg.float().to(self.device),
                       "mask": mask.to(self.device)}
        data_targets = {"y": data.y.to(self.device)}
        return data_inputs, data_targets

    def format_output(self, data_targets, model_outputs):
        """Return ``(targets, probabilities)``, detached and on CPU.

        The probabilities are the sigmoid of the model's logits.
        """
        targets = data_targets["y"].detach().cpu()
        predictions = model_outputs["logits"].detach().cpu().sigmoid()
        return targets, predictions


class MLPMixer(nn.Module):
    """Encode each row of an edge-index mask into one ``model_dim`` vector.

    Pipeline: per-edge features from ``FeatureEncoder`` are gathered through
    ``mask`` (one stack of edge features per mask row), passed through
    ``num_layers`` ``MixerBlock``s, layer-normalised, pooled over dim 1 with
    ``PoolingLayer`` and projected by a final linear layer.

    Input : ``t`` (one time value per edge), ``msg`` (2-D per-edge features
        with ``msg_dim`` columns) and ``mask`` (integer edge indices; the
        value ``num_edges`` selects an all-zero padding row).
    Output: one ``model_dim`` vector per mask row, e.g. ``(num_nodes, model_dim)``
        for the ``(num_nodes, num_neighbours)`` mask built by
        ``GraphMixer.format_input``. The mixer blocks are built with
        ``token_dim=num_nodes``, so the mask must have ``num_nodes`` rows.
    """
    def __init__(self, num_nodes, msg_dim, time_dim, model_dim, token_expansion, channel_expansion, 
                 num_layers, dropout, pooling_method, **kwargs):
        super().__init__()
        
        self.feature_encoder = FeatureEncoder(msg_dim, time_dim, model_dim)
        self.mixer_blocks = torch.nn.ModuleList()
        for i in range(num_layers):
            self.mixer_blocks.append(MixerBlock(token_dim=num_nodes,
                                                channel_dim=model_dim, 
                                                token_expansion=token_expansion, 
                                                channel_expansion=channel_expansion, 
                                                dropout=dropout))
        self.norm = nn.LayerNorm(model_dim)
        self.pool = PoolingLayer(pooling_method)
        self.linear = nn.Linear(model_dim, model_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.feature_encoder.reset_parameters()
        for block in self.mixer_blocks:
            block.reset_parameters()
        self.norm.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, t, msg, mask):
        edge_feats = self.feature_encoder(t, msg)
        # Append one all-zero row at index num_edges so padded mask slots gather
        # zeros. new_zeros (unlike zeros_like(edge_feats[0:1])) still produces
        # that row when the batch contains no edges.
        padding = edge_feats.new_zeros((1, edge_feats.shape[-1]))
        edge_feats = torch.cat([edge_feats, padding])
        x = edge_feats[mask]
        for block in self.mixer_blocks:
            x = block(x)
        x = self.norm(x)
        x = self.pool(x)
        return self.linear(x)


class PoolingLayer(nn.Module):
    """Reduce a tensor over dimension 1 with a fixed pooling method.

    Args:
        pooling_method: ``"mean"`` or ``"max"``.
    """

    def __init__(self, pooling_method):
        super().__init__()
        self.method = pooling_method

    def forward(self, x):
        """Pool ``x`` over dim 1, e.g. ``(N, K, C) -> (N, C)``.

        Raises:
            ValueError: if ``self.method`` is not ``"mean"`` or ``"max"``.
        """
        if self.method == 'mean':
            return torch.mean(x, dim=1)
        if self.method == 'max':
            # torch.max with ``dim`` returns a (values, indices) namedtuple;
            # only the values form the pooled tensor.
            return torch.max(x, dim=1).values
        raise ValueError(f"Unsupported pooling method: {self.method!r} (expected 'mean' or 'max')")


class MixerBlock(nn.Module):
    """One MLP-Mixer layer: a token-mixing MLP followed by a channel-mixing MLP.

    Each MLP is preceded by a LayerNorm over the last axis and wrapped in a
    residual connection. The input must be 3-D, with ``channel_dim`` entries
    on its last axis and ``token_dim`` entries on its first axis.

    NOTE(review): token mixing runs along axis 0 of the input
    (``permute(2, 1, 0)``), which in ``MLPMixer`` is the node axis, rather than
    along axis 1 (the neighbour axis) — confirm this is intended.
    """

    def __init__(self, token_dim, channel_dim, token_expansion=0.5, channel_expansion=4.0, dropout=0.0):
        super().__init__()
        # Sub-module creation order is kept as-is: it fixes the RNG draws used
        # for weight initialisation.
        self.token_norm = nn.LayerNorm(channel_dim)
        self.token_forward = FeedForward(token_dim, token_expansion, dropout)
        self.channel_norm = nn.LayerNorm(channel_dim)
        self.channel_forward = FeedForward(channel_dim, channel_expansion, dropout)

    def reset_parameters(self):
        """Re-initialise both norms and both MLPs (in creation order)."""
        for module in (self.token_norm, self.token_forward, self.channel_norm, self.channel_forward):
            module.reset_parameters()

    def token_mixer(self, x):
        """Normalise over the last axis, then apply the token MLP along axis 0."""
        swapped = self.token_norm(x).permute(2, 1, 0)
        mixed = self.token_forward(swapped)
        return mixed.permute(2, 1, 0)

    def channel_mixer(self, x):
        """Normalise, then apply the channel MLP along the last axis."""
        return self.channel_forward(self.channel_norm(x))

    def forward(self, x):
        after_tokens = x + self.token_mixer(x)
        return after_tokens + self.channel_mixer(after_tokens)


class FeatureEncoder(nn.Module):
    """Project per-edge features joined with a time encoding to ``model_dim``.

    ``forward(t, msg)`` concatenates ``msg`` (2-D, ``msg_dim`` columns) with
    ``TimeEncoder(t)`` along the column axis and applies one linear layer.
    ``msg`` needs one row per element of ``t``.
    """

    def __init__(self, msg_dim, time_dim, model_dim):
        super().__init__()
        self.time_encoder = TimeEncoder(time_dim)
        self.linear = nn.Linear(time_dim + msg_dim, model_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the time encoder and then the projection layer."""
        for module in (self.time_encoder, self.linear):
            module.reset_parameters()

    def forward(self, t, msg):
        joined = torch.cat((msg, self.time_encoder(t)), dim=1)
        return self.linear(joined)


class TimeEncoder(nn.Module):
    """Fixed (non-trainable) cosine time encoding.

    Maps every scalar ``t`` to ``cos(t * w)`` with ``dim`` frequencies ``w``
    spaced geometrically from ``1`` down to ``1e-9``. The exponents come from
    ``np.linspace(0, 9, dim)``. The bias is zero and no parameter requires
    gradients.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.linear = nn.Linear(1, dim)
        self.reset_parameters()

    def reset_parameters(self):
        """(Re)write the fixed frequencies and zero bias in place and freeze them.

        Writing into the existing parameters, instead of assigning new
        ``nn.Parameter`` objects, keeps them on the device and dtype the
        module was moved to. It also keeps them the same objects that
        optimizers or hooks already reference.
        """
        freqs = torch.from_numpy(1 / 10 ** np.linspace(0, 9, self.dim, dtype=np.float32))
        with torch.no_grad():
            self.linear.weight.copy_(freqs.reshape(self.dim, -1))
            self.linear.bias.zero_()
        self.linear.weight.requires_grad_(False)
        self.linear.bias.requires_grad_(False)

    @torch.no_grad()
    def forward(self, t):
        """Encode ``t`` (any shape, flattened) into a ``(t.numel(), dim)`` tensor."""
        # Cast to the weight dtype so integer (or float64) timestamps are accepted.
        t = t.reshape(-1, 1).to(self.linear.weight.dtype)
        return torch.cos(self.linear(t))


class FeedForward(nn.Module):
    """Two-layer MLP on the last axis: ``dim -> int(expansion_factor * dim) -> dim``.

    GELU follows the first layer. Dropout with probability ``dropout`` follows
    each layer and is active only in training mode.
    """

    def __init__(self, dim, expansion_factor, dropout=0.0):
        super().__init__()
        self.dim = dim
        self.expansion_factor = expansion_factor
        self.dropout = dropout
        hidden_dim = int(expansion_factor * dim)
        self.linear1 = nn.Linear(dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both linear layers (first, then second)."""
        for layer in (self.linear1, self.linear2):
            layer.reset_parameters()

    def _drop(self, x):
        # Dropout that follows the module's train/eval mode.
        return F.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x):
        hidden = self._drop(F.gelu(self.linear1(x)))
        return self._drop(self.linear2(hidden))


class LinkClassifier(torch.nn.Module):
    """Score (source, destination) embedding pairs with a small MLP.

    Computes ``linear_fin(relu(linear_src(z_src) + linear_dst(z_dst)))`` and
    returns one logit per pair (the trailing size-1 output axis is dropped).
    """

    def __init__(self, dim):
        super().__init__()
        self.linear_src = nn.Linear(dim, dim)
        self.linear_dst = nn.Linear(dim, dim)
        self.linear_fin = nn.Linear(dim, 1)

    def reset_parameters(self):
        """Re-initialise all three linear layers."""
        self.linear_src.reset_parameters()
        self.linear_dst.reset_parameters()
        self.linear_fin.reset_parameters()

    def forward(self, z_src, z_dst):
        h = self.linear_src(z_src) + self.linear_dst(z_dst)
        h = h.relu()
        # Drop only the trailing output axis; a bare squeeze() would also
        # collapse a batch containing a single pair into a 0-d tensor.
        return self.linear_fin(h).squeeze(-1)









